import tqdm
import h5py
import os
import hashlib
import pickle
import numpy as np

from core.components.data_loader import GlobalSplitVideoDataset
from core.data.dmcontrol.dmcontrol_data_loader import DMControlDataset


class RescaleDataset(GlobalSplitVideoDataset):
    """Rescales a key in loaded data by subtracting the mean and dividing by the standard deviation.

    The statistics are obtained once by `compute_stats` (computed from the dataset files or loaded
    from an on-disk cache) and applied to individual samples by `rescale_val`.
    """

    def compute_stats(self, filenames, name='value', max_files=10000):
        """Computes rescaling statistics for given key across dataset and stores them.

        The results are stored in `self._data_mean` and `self._data_std`. If the cache file returned
        by `get_dump_filename(name)` exists, the statistics are loaded from it and no data file is
        read; otherwise they are computed over all values found and written to that cache file.

        :param filenames: sequence of HDF5 file paths; each file must contain 'traj0/<name>'.
        :param name: key below 'traj0/' whose values the statistics are computed over.
        :param max_files: only the first `max_files` entries of `filenames` are read (None reads
            all of them). The cache file name does not depend on this value.
        :raises ValueError: if statistics have to be computed but no file is available.
        """
        dump_file = self.get_dump_filename(name)
        if os.path.exists(dump_file):  # preload from file if exists
            print("Loading dataset statistics from {}".format(dump_file))
            # NOTE: unpickling can execute arbitrary code -- the cache directory must be trusted.
            with open(dump_file, "rb") as F:
                self._data_mean, self._data_std = pickle.load(F)
            return

        # Collect per-file arrays instead of Python lists of boxed floats: same values, far less
        # memory. float64 matches the precision the list-based accumulation produced.
        chunks = []
        for filename in tqdm.tqdm(filenames[:max_files]):
            with h5py.File(filename, 'r') as F:
                chunks.append(np.atleast_1d(np.asarray(F['traj0/' + name][()], dtype=np.float64)))

        if not chunks:
            # Refuse to compute (and cache!) NaN statistics: the cache key ignores `filenames`,
            # so a poisoned cache file would silently be reused by every later run.
            raise ValueError("Cannot compute dataset statistics for '{}': no files given.".format(name))

        vals = np.concatenate(chunks, axis=0)
        del chunks
        self._data_mean, self._data_std = vals.mean(), vals.std()
        del vals
        print("Dumping dataset statistics to {}".format(dump_file))
        os.makedirs(self.dump_dir, exist_ok=True)
        with open(dump_file, "wb") as F:
            pickle.dump((self._data_mean, self._data_std), F)

    def rescale_val(self, raw_data, name='discounted_returns'):
        """Rescales `raw_data[name]` in place based on extracted statistics.

        `compute_stats` must have been called before so that the statistics are available.

        :param raw_data: mapping holding the loaded data; entry `name` is overwritten.
        :param name: key of the entry that gets normalized.
        """
        # A constant-valued dataset has zero std; divide by 1 so the result is the centered
        # value rather than NaN/inf.
        scale = self._data_std if self._data_std != 0 else 1.0
        raw_data[name] = (raw_data[name] - self._data_mean) / scale

    def get_dump_filename(self, name):
        """Returns filename for dumped data. Depends on value name and data_dir name."""
        # md5 is only used to derive a filesystem-safe, fixed-length file name.
        digest = hashlib.md5((self.data_dir + name).encode()).hexdigest()
        return os.path.join(self.dump_dir, digest + '.pkl')

    @property
    def dump_dir(self):
        """Directory holding the cached statistics, located below the DATA_DIR environment variable."""
        return os.path.join(os.environ['DATA_DIR'], 'temp', 'data_stats')


class DMControlRescaleDataset(DMControlDataset, RescaleDataset):
    """DMControl dataset that rescales one entry of every loaded sample with dataset statistics.

    Two optional spec entries control the behavior:
      - 'rescale_data_name': key the statistics are computed from (default 'value').
      - 'rescale_batch_name': key of the loaded data that gets rescaled
        (default 'discounted_returns').
    """

    def __init__(self, *args, **kwargs):
        """Builds the underlying dataset, then computes (or loads) the rescaling statistics."""
        super().__init__(*args, **kwargs)

        # Key the statistics are computed over.
        if 'rescale_data_name' in self.spec:
            stats_key = self.spec.rescale_data_name
        else:
            stats_key = 'value'
        self.compute_stats(self.filenames, name=stats_key)

        # Key of the loaded data that is normalized in `_load_raw_data`.
        if 'rescale_batch_name' in self.spec:
            self._rescale_batch_name = self.spec.rescale_batch_name
        else:
            self._rescale_batch_name = 'discounted_returns'

    def _load_raw_data(self, data, F):
        """Lets the parent class fill `data` from `F`, then rescales the configured entry in place."""
        super()._load_raw_data(data, F)
        self.rescale_val(data, name=self._rescale_batch_name)
